"""
This is the main preprocessing script.
1. Read the raw hourly HYCOM data for one day.
2. Regrid the HYCOM data from the curvilinear model grid to a 4 km rectilinear grid.
3. Apply a Gaussian spatial filter to the total SSH.
4. Sample the filtered SSH, along with the total and steric SSH, from the 4 km grid at the altimeter track points.
5. Save the hourly HYCOM data on the altimeter tracks as daily NetCDF files.


Author: Badarvada Yadidya
Created on: August 2023

"""

import xarray as xr
import sys
import numpy as np
import pandas as pd
import os
import pyinterp
import utide as ut
from scipy.spatial import cKDTree
import scipy.ndimage as ndimage

# Static inputs opened once at import time:
#   grid_info  - HYCOM model-grid coordinates (hycom_lon / hycom_lat).
#   nlyss_info - altimeter track point coordinates (lon / lat). This path is
#                relative, so it resolves against the current working directory.
grid_info, nlyss_info = (
    xr.open_dataset(path)
    for path in ("/nobackup/ybadarva/hycom/grid_info.nc", "nlyss_info.nc")
)

def gaussian_filter(data, sigma):
    """Smooth *data* with a Gaussian kernel.

    Thin wrapper around ``scipy.ndimage.gaussian_filter`` with its default
    options. *sigma* is the kernel standard deviation in array cells (a
    scalar, or one value per axis) and is passed through unchanged.
    """
    smoothed = ndimage.gaussian_filter(data, sigma=sigma)
    return smoothed

def compute_sigma(lambda_val, res):
    """Convert a length scale into a Gaussian sigma measured in grid cells.

    ``sigma = (lambda_val / res) / (2 * pi)``

    *lambda_val* (presumably the filter cut-off wavelength) and *res* (the
    grid spacing) must be in the same units.
    """
    cells_per_scale = lambda_val / res
    return cells_per_scale / (2 * np.pi)
# Gaussian filter width in grid cells for a length scale of 300 on the 4 km
# grid (same units as the spacing, i.e. km); see compute_sigma.
# To extract the diurnal signal, 600 is used instead of 300 to compute sigma.
sigma = compute_sigma(300, 4)

resolution = 0.036    # regular-grid spacing in degrees (~4 km)

# Regular target grid axes: longitude [0, 360), latitude [-67, 67).
mx_lon = np.arange(0, 360, resolution)
my_lat = np.arange(-67, 67, resolution)

# Flattened HYCOM model-grid coordinates.
lon = grid_info.hycom_lon.data.ravel()
lat = grid_info.hycom_lat.data.ravel()

# Altimeter track coordinates (transposed). Longitudes are wrapped into
# [0, 360) to match mx_lon. The modulo also maps exactly-360 to 0 and handles
# offsets of any multiple of 360. NaN coordinates stay NaN.
altim_lon = nlyss_info.lon.T % 360
altim_lat = nlyss_info.lat.T

# Flattened coordinates of every regular-grid node. With indexing='ij' the
# meshgrid arrays have shape (len(mx_lon), len(my_lat)), and the C-order
# ravel walks latitude fastest.
mgx, mgy = np.meshgrid(mx_lon, my_lat, indexing='ij')
mx = mgx.ravel()
my = mgy.ravel()

def _idw_to_grid(src_points, values, dst_points, shape):
    """Interpolate scattered *values* onto *dst_points* by inverse distance weighting.

    *src_points* and *dst_points* are (N, 2) arrays of (lon, lat). The result
    is reshaped to *shape* and wrapped in an unlabeled DataArray (dims
    ``dim_0`` / ``dim_1``) so the caller can use ``interpolate_na`` on it.
    """
    mesh = pyinterp.RTree()
    mesh.packing(src_points, values)
    gridded, _ = mesh.inverse_distance_weighting(
        dst_points,
        # NOTE(review): an earlier comment here said "Extrapolation is
        # forbidden", but pyinterp documents within=True as the setting that
        # prevents extrapolation. within=False is kept to preserve results;
        # confirm which is intended.
        within=False,
        k=11,  # use at most 11 neighbours
        num_threads=0)
    return xr.DataArray(gridded.reshape(shape))


def _sample_tracks(field, nearest, valid, track_shape):
    """Return *field* at the nearest regular-grid node of each track point.

    *nearest* holds 0-based indices into the C-order flattening of *field*.
    Entries where *valid* is False (non-finite track coordinates) are NaN.
    """
    flat = np.asarray(field).ravel()
    sampled = np.full(nearest.shape, np.nan)
    sampled[valid] = flat[nearest[valid]]
    return sampled.reshape(track_shape)


def process_file(day, lon, lat, mx, my, new_grid_lon, new_grid_lat, sigma,
                 regular_shape=(10000, 3723)):
    """Regrid one day of hourly HYCOM SSH and sample it on altimeter tracks.

    For every time step in ``216_archs.2018_<day>.nc`` this:
      1. interpolates total (``srfhgt``) and steric (``steric``) SSH from the
         HYCOM grid onto the regular grid by inverse distance weighting
         (up to 11 neighbours);
      2. Gaussian-filters the total SSH on the regular grid after linearly
         filling gaps along both axes, then restores NaN wherever the
         unfiltered field was missing;
      3. samples total, filtered and steric SSH at the altimeter track points
         by nearest-neighbour lookup on the regular grid.

    Parameters
    ----------
    day : int or str
        Day of year. It is zero-padded to three digits to build the file name.
    lon, lat : ndarray
        Flattened HYCOM grid longitudes / latitudes. They are assumed to be
        ordered like the transposed, raveled model fields (TODO confirm).
    mx, my : ndarray
        Flattened regular-grid longitudes / latitudes (C-order ravel of an
        ``indexing='ij'`` meshgrid of shape *regular_shape*).
    new_grid_lon, new_grid_lat : array-like, 2-D
        Altimeter track point coordinates. Their common shape is the shape of
        each hourly output slice (``np`` x ``nt``).
    sigma : float
        Gaussian filter standard deviation in regular-grid cells.
    regular_shape : tuple of int, optional
        Shape of the regular grid that *mx* / *my* were raveled from.

    Returns
    -------
    xarray.Dataset
        Variables ``tssh`` (total), ``fssh`` (filtered total) and ``sssh``
        (steric), each with dims ``("time", "np", "nt")``.

    Raises
    ------
    ValueError
        If *mx* / *my* do not contain ``prod(regular_shape)`` points.
    """
    # Zero-pad the day of year, e.g. 5 -> "005".
    day = str(day).zfill(3)

    # Construct the file name
    filename = f"216_archs.2018_{day}.nc"
    ddir = "/nobackup/ybadarva/2018/nc/"   # Input Directory

    # Loop-invariant geometry: built once per day, not once per hour.
    hycom_points = np.vstack((lon, lat)).T
    grid_points = np.vstack((mx, my)).T
    if grid_points.shape[0] != int(np.prod(regular_shape)):
        raise ValueError(
            f"mx/my have {grid_points.shape[0]} points, which does not match "
            f"regular_shape {tuple(regular_shape)}")

    track_lon = np.asarray(new_grid_lon)
    track_lat = np.asarray(new_grid_lat)
    track_shape = track_lon.shape
    track_points = np.vstack((track_lon.ravel(), track_lat.ravel())).T

    # Track points with NaN/inf coordinates get no neighbour and a NaN result.
    valid = np.isfinite(track_points).all(axis=1)
    nearest = np.zeros(track_points.shape[0], dtype=np.intp)
    if valid.any():
        # cKDTree.query returns 0-based row indices into grid_points, which
        # is the C-order flattening of a regular_shape field. Use them
        # directly: a -1 offset would select the adjacent grid node.
        _, nearest[valid] = cKDTree(grid_points).query(track_points[valid])

    with xr.open_dataset(ddir + filename) as ds:
        n_hours = ds.time.size
        tssh = np.full((n_hours, *track_shape), np.nan)
        sssh = np.full((n_hours, *track_shape), np.nan)
        fssh = np.full((n_hours, *track_shape), np.nan)

        for hour in range(n_hours):
            ts_rav = ds.srfhgt[hour].T.values.ravel()
            ss_rav = ds.steric[hour].T.values.ravel()

            # HYCOM grid to evenly-spaced grid
            tssh_eg = _idw_to_grid(hycom_points, ts_rav, grid_points, regular_shape)
            sssh_eg = _idw_to_grid(hycom_points, ss_rav, grid_points, regular_shape)

            # Gaussian spatial filtering of tssh. Gaps are filled linearly
            # first (ndimage.gaussian_filter would otherwise smear NaNs), and
            # the original gaps are restored afterwards.
            tssh_filled = (tssh_eg
                           .interpolate_na(dim='dim_0', method='linear')
                           .interpolate_na(dim='dim_1', method='linear'))
            tssh_smooth = gaussian_filter(tssh_filled, sigma)
            tssh_filt = xr.where(tssh_eg.isnull(), np.nan, tssh_smooth)

            # Evenly-spaced grid to altimeter track points
            tssh[hour] = _sample_tracks(tssh_eg, nearest, valid, track_shape)
            fssh[hour] = _sample_tracks(tssh_filt, nearest, valid, track_shape)
            sssh[hour] = _sample_tracks(sssh_eg, nearest, valid, track_shape)
            print('Altimeter done')

        # Pull the time coordinate into memory before the file is closed.
        time_coord = ds.time.load()

    # Create a dictionary of DataArrays
    dims = ("time", "np", "nt")
    data_vars = {
        "tssh": (dims, tssh),
        "fssh": (dims, fssh),
        "sssh": (dims, sssh),
    }

    # Create a new xarray Dataset for this day
    ds_new_grid_data = xr.Dataset(
        data_vars=data_vars,
        coords={
            "time": time_coord,
            "np": range(track_shape[0]),
            "nt": range(track_shape[1]),
        },
    )

    return ds_new_grid_data

if __name__ == '__main__':
    # Usage: python <script> DAY_OF_YEAR   (e.g. 5 or 005)
    if len(sys.argv) != 2 or not sys.argv[1].isdigit():
        sys.exit(f"usage: {sys.argv[0]} DAY_OF_YEAR  (e.g. 5 or 005)")

    # Normalize the day once and use the same string for input and output.
    day = str(sys.argv[1]).zfill(3)

    combined_ds = process_file(day, lon, lat, mx, my, altim_lon, altim_lat, sigma)

    # Write the combined dataset to a single NetCDF file
    ddir = "/nobackup/ybadarva/2018/nc/"   # Output directory
    combined_ds.to_netcdf(os.path.join(ddir, f"alt_216_archs.2018_{day}.nc"))
    
# The daily NetCDF files are combined into a Zarr store for harmonic analysis (done outside this script).